{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tutorial on Hadamard Test Based Gradient Estimation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set Up\n",
    "\n",
    "Three example circuits are defined in `./circ.py`; this tutorial uses `Circ1`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from hadamard_grad import hadamard_grad\n",
    "from circ import Circ1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run Circuit\n",
    "\n",
    "The circuit must be run once before estimating gradients, so that its operations are recorded in the `op_history` of the q_device."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "expval:\n",
      "tensor([0.8776], grad_fn=<SelectBackward0>)\n",
      "op_history:\n",
      "[{'name': 'OpPauliExp', 'wires': [0, 1, 2, 3], 'coeffs': [1.0], 'paulis': ['YXIX'], 'inverse': False, 'trainable': True, 'params': 0.5}]\n"
     ]
    }
   ],
   "source": [
    "circ1 = Circ1()\n",
    "expval1, qdev1 = circ1()\n",
    "print('expval:')\n",
    "print(expval1)\n",
    "print('op_history:')\n",
    "print(qdev1.op_history)"
   ]
  },
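  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Each entry of `op_history` is a plain dict, as the output above shows. The cell below is a minimal sketch of how such records could be filtered for operations flagged as trainable. The names `op_history_example` and `trainable_ops` are hypothetical, and the list is a literal copy of the printed record so that the cell runs without torchquantum."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Literal copy of the record printed above (assumption: the real op_history has the same shape).\n",
    "op_history_example = [\n",
    "    {'name': 'OpPauliExp', 'wires': [0, 1, 2, 3], 'coeffs': [1.0],\n",
    "     'paulis': ['YXIX'], 'inverse': False, 'trainable': True, 'params': 0.5}\n",
    "]\n",
    "\n",
    "# Collect (index, name, Pauli strings, parameter) for every op flagged as trainable.\n",
    "trainable_ops = [\n",
    "    (idx, op['name'], op['paulis'], op['params'])\n",
    "    for idx, op in enumerate(op_history_example)\n",
    "    if op.get('trainable', False)\n",
    "]\n",
    "print(trainable_ops)"
   ]
  },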
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Gradient Estimation\n",
    "\n",
    "For compatibility with torchquantum v0.1.7, the following three pieces of information must be extracted manually from the q_device and the circuit:\n",
    "\n",
    "- `op_history`: the history of quantum operations applied to the q_device.\n",
    "- `n_wires`: the number of wires (qubits) in the q_device.\n",
    "- `observable`: the observable of the original circuit, for which the gradients are computed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "op_history = qdev1.op_history\n",
    "n_wires = qdev1.n_wires\n",
    "observable = 'ZZZZ'"
   ]
  },
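  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The value that `hadamard_grad` should return for `Circ1` can be worked out by hand. The sketch below assumes that `OpPauliExp` applies $U(\\theta) = e^{-i\\theta P/2}$ with $P = YXIX$ and that the device starts in $|0000\\rangle$. Both are assumptions about conventions rather than statements about the library, although they agree with the `expval` printed earlier. The symbols $U$, $P$, $O$ and $|\\phi\\rangle$ are introduced here for illustration only.\n",
    "\n",
    "With $O = ZZZZ$ and $|\\phi(\\theta)\\rangle = U(\\theta)|0000\\rangle$:\n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "U(\\theta) &= \\cos\\tfrac{\\theta}{2}\\, I - i \\sin\\tfrac{\\theta}{2}\\, P, \\\\\n",
    "\\langle O \\rangle(\\theta) &= \\langle\\phi(\\theta)|O|\\phi(\\theta)\\rangle = \\cos^2\\tfrac{\\theta}{2} - \\sin^2\\tfrac{\\theta}{2} = \\cos\\theta, \\\\\n",
    "\\frac{\\partial \\langle O \\rangle}{\\partial \\theta} &= \\frac{i}{2}\\,\\langle\\phi(\\theta)|[P, O]|\\phi(\\theta)\\rangle = \\operatorname{Im}\\,\\langle\\phi(\\theta)|O P|\\phi(\\theta)\\rangle = -\\sin\\theta.\n",
    "\\end{aligned}\n",
    "$$\n",
    "\n",
    "The second line uses $O|0000\\rangle = |0000\\rangle$ and the fact that $O$ anticommutes with $P$ (three of the four tensor factors anticommute). The quantity $\\operatorname{Im}\\,\\langle\\phi|OP|\\phi\\rangle$ in the third line is the imaginary part of the expectation value of a unitary, which is the kind of quantity a Hadamard test estimates with one ancilla qubit. At $\\theta = 0.5$ these expressions give $\\cos(0.5) \\approx 0.8776$ and $-\\sin(0.5) \\approx -0.4794$."
   ]
  },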
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/vast/palmer/home.grace/dl2276/AllProjects/torchquantum/torchquantum/functional/func_controlled_unitary.py:99: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  unitary = torch.tensor(torch.zeros(2**n_wires, 2**n_wires, dtype=C_DTYPE))\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[tensor(-0.4794)]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "hadamard_grad(op_history, n_wires, observable)\n",
    "# Expected value: -sin(0.5) = -0.47942554 (to 8 decimal places)"
   ]
  }
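  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an independent cross-check, the same gradient can be reproduced with dense matrices in NumPy. This is only a sketch under stated assumptions, not how `hadamard_grad` is implemented: it assumes the operation is $e^{-i\\theta P/2}$ with $P = YXIX$ acting on $|0000\\rangle$, and it evaluates $\\operatorname{Im}\\,\\langle\\phi|OP|\\phi\\rangle$ (the quantity a Hadamard test would estimate) directly from the state vector. All function names are hypothetical helpers. A central finite difference is included as a second reference."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from functools import reduce\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "# Single-qubit Pauli matrices.\n",
    "PAULI = {\n",
    "    'I': np.eye(2, dtype=complex),\n",
    "    'X': np.array([[0, 1], [1, 0]], dtype=complex),\n",
    "    'Y': np.array([[0, -1j], [1j, 0]], dtype=complex),\n",
    "    'Z': np.array([[1, 0], [0, -1]], dtype=complex),\n",
    "}\n",
    "\n",
    "\n",
    "def pauli_string(label):\n",
    "    # Tensor product of single-qubit Paulis, e.g. 'YXIX' -> Y (x) X (x) I (x) X.\n",
    "    return reduce(np.kron, [PAULI[ch] for ch in label])\n",
    "\n",
    "\n",
    "def evolved_state(theta, generator):\n",
    "    # Assumed convention: exp(-i * theta / 2 * P) applied to |0...0>.\n",
    "    # Because P squares to the identity, the exponential has a closed form.\n",
    "    dim = generator.shape[0]\n",
    "    psi0 = np.zeros(dim, dtype=complex)\n",
    "    psi0[0] = 1.0\n",
    "    unitary = np.cos(theta / 2) * np.eye(dim) - 1j * np.sin(theta / 2) * generator\n",
    "    return unitary @ psi0\n",
    "\n",
    "\n",
    "def expectation(theta, generator, obs):\n",
    "    phi = evolved_state(theta, generator)\n",
    "    return np.vdot(phi, obs @ phi).real\n",
    "\n",
    "\n",
    "def commutator_gradient(theta, generator, obs):\n",
    "    # d<O>/d(theta) = Im <phi| O P |phi> under the convention assumed above.\n",
    "    phi = evolved_state(theta, generator)\n",
    "    return np.vdot(phi, obs @ generator @ phi).imag\n",
    "\n",
    "\n",
    "theta = 0.5\n",
    "P = pauli_string('YXIX')\n",
    "O = pauli_string('ZZZZ')\n",
    "\n",
    "value = expectation(theta, P, O)\n",
    "grad = commutator_gradient(theta, P, O)\n",
    "\n",
    "# Central finite difference as a second, independent reference.\n",
    "eps = 1e-6\n",
    "fd = (expectation(theta + eps, P, O) - expectation(theta - eps, P, O)) / (2 * eps)\n",
    "\n",
    "print('expval:', value)\n",
    "print('gradient (Im<OP>):', grad)\n",
    "print('gradient (finite difference):', fd)"
   ]
  }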
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10.11 ('tq')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "0e29a0e2839e50afa2f3a774fbdc59fa5f031cf090cdc0c9e6a9a24240713eb0"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
